import numpy as np
import pandas as pd
import re
import matplotlib.pyplot as plt
from scipy.optimize import minimize
from tqdm import tqdm
from joblib import Parallel, delayed
import logging
import time

# Setup logging: records at INFO and above go to symbolic_fit.log, which is
# truncated (filemode="w") when this module-level call runs. basicConfig is a
# no-op if the root logger already has handlers.
logging.basicConfig(level=logging.INFO, filename="symbolic_fit.log", filemode="w")

# Extended primes list: the primes from 2 up to 997 in increasing order
# (168 entries). D() indexes this table cyclically (index modulo len(PRIMES)).
PRIMES = [
    2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71,
    73, 79, 83, 89, 97, 101, 103, 107, 109, 113, 127, 131, 137, 139, 149, 151,
    157, 163, 167, 173, 179, 181, 191, 193, 197, 199, 211, 223, 227, 229, 233,
    239, 241, 251, 257, 263, 269, 271, 277, 281, 283, 293, 307, 311, 313, 317,
    331, 337, 347, 349, 353, 359, 367, 373, 379, 383, 389, 397, 401, 409, 419,
    421, 431, 433, 439, 443, 449, 457, 461, 463, 467, 479, 487, 491, 499, 503,
    509, 521, 523, 541, 547, 557, 563, 569, 571, 577, 587, 593, 599, 601, 607,
    613, 617, 619, 631, 641, 643, 647, 653, 659, 661, 673, 677, 683, 691, 701,
    709, 719, 727, 733, 739, 743, 751, 757, 761, 769, 773, 787, 797, 809, 811,
    821, 823, 827, 829, 839, 853, 857, 859, 863, 877, 881, 883, 887, 907, 911,
    919, 929, 937, 941, 947, 953, 967, 971, 977, 983, 991, 997
]

# Golden ratio as a numpy float64; used by fib_real() and D().
phi = (1 + np.sqrt(5)) / 2
# Memo table for fib_real(), keyed by its (possibly non-integer) argument.
# fit_single_constant() clears it after each constant.
fib_cache = {}

def fib_real(n):
    """Return the Fibonacci number generalized to a real index ``n``.

    Uses the continuous Binet form

        F(n) = (phi**n - phi**(-n) * cos(pi * n)) / sqrt(5)

    which reproduces the ordinary Fibonacci numbers at integer ``n``
    (F(0) = 0, F(1) = 1, F(2) = 1, ...) and is defined for any real ``n``.
    Computed values are memoized in the module-level ``fib_cache``.

    For ``n > 100`` the function returns ``0.0`` and does not cache it.
    NOTE(review): with that cutoff D() falls back to its 1e-30 floor for every
    n + beta > 100 -- confirm the cutoff is intended.
    """
    if n in fib_cache:
        return fib_cache[n]
    from math import cos, pi, sqrt
    if n > 100:
        return 0.0
    phi_inv = 1 / phi
    # Both terms carry the 1/sqrt(5) factor; dividing only the first one
    # breaks the integer values (F(0) would be 1/sqrt(5) - 1 instead of 0).
    result = (phi**n - (phi_inv**n) * cos(pi * n)) / sqrt(5)
    fib_cache[n] = result
    return result

def D(n, beta, r=1.0, k=1.0, Omega=1.0, base=2, scale=1.0):
    """Evaluate the symbolic dimension model at ``(n, beta)``.

    With ``x = n + beta`` the model is

        sqrt(max(scale * phi * F(x) * base**x * P(x) * Omega, 1e-30)) * r**k

    where ``F`` is ``fib_real`` and ``P(x)`` is ``PRIMES[floor(x)]`` taken
    cyclically (index modulo ``len(PRIMES)``). When ``n > 1000`` the product
    is additionally multiplied by ``log(n) / log(1000)`` before the root.

    NOTE(review): while fib_real returns 0.0 for arguments above 100, the
    product is already 0 (or NaN) whenever n > 1000, so that log factor has no
    effect. Also ``max(nan, 1e-30)`` is NaN, so a NaN product (e.g. 0 * inf
    once ``base**x`` overflows) propagates to the return value.
    """
    x = n + beta
    fib_term = fib_real(x)
    table_size = len(PRIMES)
    prime = PRIMES[int(np.floor(x) + table_size) % table_size]
    # Keep the left-to-right multiplication order: float products are not
    # associative.
    product = scale * phi * fib_term * base**x * prime * Omega
    if n > 1000:
        product *= np.log(n) / np.log(1000)
    clamped = max(product, 1e-30)
    return np.sqrt(clamped) * r**k

def invert_D(value, r=1.0, k=1.0, Omega=1.0, base=2, scale=1.0, max_n=500, steps=200):
    """Brute-force search for D() parameters that reproduce ``value``.

    The search grid covers:
      * n: ``steps`` points, linear on [0, max_n] when log10(value) <= 3,
        logarithmic on [1, max_n] above that;
      * beta: 10 points on [0, 1];
      * 20 multiplicative scale factors spanning 10**(log10(value) +/- 5);
      * r_local, k_local in {0.5, 1.0}.
    The ten candidates closest to ``value`` are kept.

    Args:
        value: target value to reproduce.
        r, k: accepted for interface compatibility but NOT used; the search
            always tries r_local, k_local in {0.5, 1.0}.
        Omega, base: forwarded to D().
        scale: base scale, multiplied by each grid scale factor.
        max_n, steps: accepted for interface compatibility but NOT used;
            both are recomputed from log10(value).

    Returns:
        ``(n, beta, dynamic_scale, emergent_uncertainty, r_local, k_local)``
        for the best candidate, where ``emergent_uncertainty`` is the standard
        deviation of D() over the ten best candidates. Returns ``None`` if the
        search raised (the error is logged).
    """
    log_val = np.log10(max(abs(value), 1e-30))
    # NOTE(review): these overwrite the max_n / steps arguments, so callers
    # passing them explicitly still get this adaptive grid.
    max_n = min(5000, max(500, int(300 * log_val)))
    steps = 100 if log_val < 3 else 200
    if log_val > 3:
        n_values = np.logspace(0, np.log10(max_n), steps)
    else:
        n_values = np.linspace(0, max_n, steps)
    betas = np.linspace(0, 1, 10)
    scale_factors = np.logspace(log_val - 5, log_val + 5, num=20)
    candidates = []
    try:
        for n in n_values:
            for beta in betas:
                for dynamic_scale in scale_factors:
                    for r_local in (0.5, 1.0):
                        for k_local in (0.5, 1.0):
                            approx = D(n, beta, r_local, k_local, Omega, base, scale * dynamic_scale)
                            diff = abs(approx - value)
                            # D() can yield NaN (e.g. 0 * inf after base**x overflows).
                            # NaN sort keys make sorted() order arbitrary, so rank
                            # such candidates last instead.
                            if np.isnan(diff):
                                diff = np.inf
                            candidates.append((diff, approx, n, beta, dynamic_scale, r_local, k_local))
        best = sorted(candidates, key=lambda cand: cand[0])[:10]
        # Reuse the stored D() values rather than evaluating D() again.
        emergent_uncertainty = np.std([cand[1] for cand in best])
        _, _, n, beta, dynamic_scale, r_local, k_local = best[0]
        return n, beta, dynamic_scale, emergent_uncertainty, r_local, k_local
    except Exception as e:
        logging.error(f"invert_D failed for value {value}: {e}")
        return None

# Columns in the CODATA listing are separated by two or more spaces; single
# spaces occur *inside* fields (digit groups "6.626 070 15 e-34", units "m s^-1").
_CODATA_COLUMN_SEPARATOR = re.compile(r"\s{2,}")
# Original single-line pattern, kept as a fallback for files whose numeric
# columns are separated by single spaces.
_CODATA_LEGACY_LINE = re.compile(r"^\s*(.*?)\s{2,}([0-9Ee\+\-\.]+(?:\.\.\.)?)\s+([0-9Ee\+\-\.]+|exact)\s+(\S.*)")


def _codata_number(text):
    """Convert a CODATA numeric field to float.

    Drops the ``...`` marker of truncated exact values and the single spaces
    used to group digits, e.g. ``"6.626 070 15 e-34"`` -> ``6.62607015e-34``.
    Raises ValueError if the remaining text is not a number.
    """
    return float(text.replace("...", "").replace(" ", ""))


def _codata_record(name, value_str, uncert_str, unit):
    """Build one parsed-constant dict; raises ValueError on a non-numeric field."""
    uncert_str = uncert_str.strip()
    return {
        "name": name.strip(),
        "value": _codata_number(value_str),
        "uncertainty": 0.0 if uncert_str in ("exact", "(exact)") else _codata_number(uncert_str),
        "unit": unit.strip(),
    }


def _codata_field_candidates(line):
    """Yield (name, value, uncertainty, unit) splits of ``line``, best guess first."""
    fields = _CODATA_COLUMN_SEPARATOR.split(line.strip())
    if len(fields) >= 3:
        # Dimensionless constants have no unit column.
        yield fields[0], fields[1], fields[2], " ".join(fields[3:])
    legacy = _CODATA_LEGACY_LINE.match(line)
    if legacy:
        yield legacy.groups()


def parse_codata_ascii(filename):
    """Parse a CODATA "allascii"-style constants listing into a DataFrame.

    Each data line holds name, value, uncertainty and (optional) unit columns
    separated by two or more spaces. Numeric fields may contain digit-group
    spaces and a ``...`` truncation marker; an uncertainty of ``exact`` or
    ``(exact)`` becomes 0.0. Blank lines, the ``Quantity`` header and dashed
    rules are skipped. Lines that split into columns but fail numeric parsing
    are logged as warnings and skipped; other non-matching lines are ignored.

    Args:
        filename: path of the listing to read.

    Returns:
        DataFrame with columns ``name``, ``value``, ``uncertainty``, ``unit``
        (no columns at all if nothing was parsed).
    """
    constants = []
    with open(filename, "r", encoding="utf-8", errors="replace") as f:
        for line in f:
            stripped = line.strip()
            # The header may be indented, so test the stripped line.
            if not stripped or stripped.startswith(("Quantity", "-")):
                continue
            record, last_error = None, None
            for fields in _codata_field_candidates(line):
                try:
                    record = _codata_record(*fields)
                    break
                except ValueError as e:
                    last_error = e
            if record is not None:
                constants.append(record)
            elif last_error is not None:
                logging.warning(f"Failed parsing line: {stripped} - {last_error}")
    return pd.DataFrame(constants)

def check_physical_consistency(df_results):
    """Cross-check the fitted parameters of physically related constants.

    For every ``(name1, name2, check, threshold, reason)`` relation whose two
    constants both appear in ``df_results``, ``check`` is applied to the first
    fitted row of each (columns ``n``, ``beta``, ``scale``, ``value``). If the
    absolute result exceeds ``threshold``, ``name2`` is reported.

    Args:
        df_results: fit results with at least the columns ``name``, ``n``,
            ``beta``, ``scale`` and ``value``.

    Returns:
        List of dicts with keys ``name``, ``value`` and ``reason``. It is empty
        when nothing is flagged or ``df_results`` has no ``name`` column (e.g.
        an empty frame produced when every fit failed).
    """
    bad_data = []
    if df_results is None or 'name' not in df_results.columns:
        return bad_data
    relations = [
        ('Planck constant', 'reduced Planck constant', lambda x, y: x['scale'] / y['scale'] - 2 * np.pi, 0.1, 'scale ratio vs. 2π'),
        ('proton mass', 'proton-electron mass ratio', lambda x, y: x['n'] - y['n'] - np.log10(1836), 0.5, 'n difference vs. log(proton-electron ratio)'),
        ('molar mass of carbon-12', 'Avogadro constant', lambda x, y: x['scale'] / y['scale'] - 12, 0.1, 'scale ratio vs. 12'),
        ('elementary charge', 'electron volt', lambda x, y: x['n'] - y['n'], 0.5, 'n difference vs. 0'),
        ('Rydberg constant', 'fine-structure constant', lambda x, y: x['n'] - 2 * y['n'] - np.log10(2 * np.pi), 0.5, 'n difference vs. log(2π)'),
        ('Boltzmann constant', 'electron volt-kelvin relationship', lambda x, y: x['scale'] / y['scale'] - 1, 0.1, 'scale ratio vs. 1'),
        ('Stefan-Boltzmann constant', 'second radiation constant', lambda x, y: x['n'] + 4 * y['n'] - np.log10(15 * 299792458**2 / (2 * np.pi**5)), 1.0, 'n relationship vs. c and k_B'),
        # NOTE(review): unlike the other ratio checks this one is not offset by
        # an expected value (e.g. "- 1"), so any ratio above 0.1 is flagged.
        ('Fermi coupling constant', 'weak mixing angle', lambda x, y: x['scale'] / (y['value']**2 / np.sqrt(2)), 0.1, 'scale vs. sin²θ_W/√2'),
        ('tau mass energy equivalent in MeV', 'tau energy equivalent', lambda x, y: x['n'] - y['n'], 0.5, 'n difference vs. 0'),
    ]
    fit_columns = ['n', 'beta', 'scale', 'value']
    present = set(df_results['name'])
    for name1, name2, check_func, threshold, reason in relations:
        if name1 not in present or name2 not in present:
            continue
        fit1 = df_results.loc[df_results['name'] == name1, fit_columns].iloc[0]
        fit2 = df_results.loc[df_results['name'] == name2, fit_columns].iloc[0]
        diff = abs(check_func(fit1, fit2))
        if diff > threshold:
            bad_data.append({
                'name': name2,
                'value': fit2['value'],
                'reason': f'Model {reason} inconsistent ({diff:.2e} > {threshold:.2e})'
            })
    return bad_data

def _median_uncertainty_for(log_val, median_uncertainties):
    """Return the median uncertainty of the magnitude bin containing ``log_val``.

    ``median_uncertainties`` maps bin keys ``min(int((log10 + 50) // 10), 10)``
    to medians, which is how symbolic_fit_all_constants builds it. Keys are not
    contiguous, so a missing key falls back to the nearest populated bin. An
    empty mapping yields None.
    """
    if not median_uncertainties:
        return None
    bin_idx = min(int((log_val + 50) // 10), 10)
    if bin_idx in median_uncertainties:
        return median_uncertainties[bin_idx]
    nearest = min(median_uncertainties, key=lambda key: abs(key - bin_idx))
    return median_uncertainties[nearest]


def fit_single_constant(row, r, k, Omega, base, scale, max_n, steps, error_threshold, median_uncertainties):
    """Fit one constant with invert_D and flag data that looks suspicious.

    Args:
        row: mapping with 'name', 'value', 'uncertainty' and 'unit' entries
            (a row of the parse_codata_ascii DataFrame).
        r, k, Omega, base, scale, max_n, steps: forwarded to invert_D / D.
        error_threshold: absolute fit error above which the outlier checks
            run (np.inf disables them).
        median_uncertainties: dict of magnitude-bin key -> median uncertainty
            (may be empty).

    Returns:
        Dict describing the fit and any bad-data reasons. Returns None when
        the value is out of range (<= 0 or > 1e50), the inversion fails, or
        the whole fit took more than 5 seconds.
    """
    start_time = time.time()
    val = row['value']
    if val <= 0 or val > 1e50:
        logging.warning(f"Skipping {row['name']}: Invalid value {val}")
        return None
    try:
        result = invert_D(val, r, k, Omega, base, scale, max_n, steps)
        if result is None:
            logging.error(f"invert_D returned None for {row['name']}")
            return None
        n, beta, dynamic_scale, emergent_uncertainty, r_local, k_local = result
        approx = D(n, beta, r_local, k_local, Omega, base, scale * dynamic_scale)
        error = abs(val - approx)
        rel_error = error / max(abs(val), 1e-30)
        log_val = np.log10(max(abs(val), 1e-30))
        uncertainty = row['uncertainty']
        bad_data_reason = []
        # Uncertainty checks (only when a positive uncertainty is stated).
        if uncertainty is not None and uncertainty > 0:
            rel_uncert = uncertainty / max(abs(val), 1e-30)
            if rel_uncert > 0.5:
                bad_data_reason.append(f"High relative uncertainty ({rel_uncert:.2e} > 0.5)")
            deviation = abs(uncertainty - emergent_uncertainty)
            if deviation > 1.5 * emergent_uncertainty or \
               deviation / max(emergent_uncertainty, 1e-30) > 1.0:
                bad_data_reason.append(f"Uncertainty deviates from emergent ({uncertainty:.2e} vs. {emergent_uncertainty:.2e})")
        # Outlier checks for fits whose error exceeds the threshold.
        if error > error_threshold and uncertainty is not None:
            median_uncert = _median_uncertainty_for(log_val, median_uncertainties)
            if uncertainty > 0 and median_uncert is not None and uncertainty < median_uncert:
                bad_data_reason.append("High error with low uncertainty")
            if uncertainty > 0 and error > 10 * uncertainty:
                bad_data_reason.append("Error exceeds 10x uncertainty")
        # Not a preemptive timeout: the work is already done, but fits slower
        # than 5 s are discarded.
        elapsed = time.time() - start_time
        if elapsed > 5:
            logging.warning(f"Timeout for {row['name']}: {elapsed:.2f} seconds")
            return None
        return {
            "name": row['name'],
            "value": val,
            "unit": row['unit'],
            "n": n,
            "beta": beta,
            "approx": approx,
            "error": error,
            "rel_error": rel_error,
            "uncertainty": uncertainty,
            "emergent_uncertainty": emergent_uncertainty,
            "r_local": r_local,
            "k_local": k_local,
            "scale": dynamic_scale,
            "bad_data": bool(bad_data_reason),
            "bad_data_reason": "; ".join(bad_data_reason)
        }
    except Exception as e:
        logging.error(f"Failed inversion for {row['name']}: {e}")
        return None
    finally:
        # Release fib_real's memo between constants (mutation only, so no
        # `global` declaration is needed).
        fib_cache.clear()

def _magnitude_bins(df_results):
    """Split fitted constants into (up to) 5 log10-magnitude bins.

    Uses quantile bins (pd.qcut) and falls back to equal-width bins (pd.cut)
    if qcut raises. The returned Series shares ``df_results``' index.
    """
    log_values = np.log10(df_results['value'].abs().clip(1e-30))
    try:
        return pd.qcut(log_values, 5, duplicates='drop')
    except Exception as e:
        logging.warning(f"pd.qcut failed: {e}. Using default binning.")
        return pd.cut(log_values, 5)


def _median_uncertainty_by_bin(df_results, bins):
    """Map bin key ``min(int((bin midpoint + 50) // 10), 10)`` -> median uncertainty."""
    medians = {}
    for interval in bins.unique():
        median_uncert = df_results.loc[bins == interval, 'uncertainty'].median()
        if not np.isnan(median_uncert):
            medians[min(int((interval.mid + 50) // 10), 10)] = median_uncert
    return medians


def _flag_uncertainty_outliers(df_results, bins):
    """Flag rows whose uncertainty exceeds Q3 + 3*IQR within their magnitude bin (in place)."""
    for interval in bins.unique():
        in_bin = bins == interval
        uncertainties = df_results.loc[in_bin, 'uncertainty'].dropna()
        if uncertainties.empty:
            continue
        q1, q3 = np.percentile(uncertainties, [25, 75])
        outliers = in_bin & (df_results['uncertainty'] > q3 + 3 * (q3 - q1))
        df_results.loc[outliers, 'bad_data'] = True
        df_results.loc[outliers, 'bad_data_reason'] = (
            df_results['bad_data_reason'] + "; Uncertainty outlier"
        ).str.strip("; ")


def symbolic_fit_all_constants(df, r=1.0, k=1.0, Omega=1.0, base=2, scale=1.0, max_n=500, steps=200):
    """Fit every constant in ``df`` and flag suspicious data.

    The fit runs in stages:
      1. A preliminary parallel pass of fit_single_constant with outlier checks
         disabled. Its results give the 95th-percentile error threshold and
         per-magnitude-bin median uncertainties.
      2. A final pass using that threshold and those medians.
      3. Physical-consistency flags from check_physical_consistency.
      4. Per-bin IQR uncertainty-outlier flags. The bins are computed on the
         final results so the masks align with their index.

    Args:
        df: constants DataFrame (as returned by parse_codata_ascii).
        r, k, Omega, base, scale, max_n, steps: model parameters forwarded
            to fit_single_constant.

    Returns:
        DataFrame of per-constant fit results. It is empty if no constant
        could be fitted.
    """
    logging.info("Starting symbolic fit for all constants...")

    def run_pass(error_threshold, median_uncertainties):
        fits = Parallel(n_jobs=-1, backend='loky')(
            delayed(fit_single_constant)(row, r, k, Omega, base, scale, max_n, steps, error_threshold, median_uncertainties)
            for _, row in df.iterrows()
        )
        return pd.DataFrame([fit for fit in fits if fit is not None])

    # Preliminary fit to derive the error threshold and median uncertainties.
    preliminary = run_pass(np.inf, {})
    if preliminary.empty:
        error_threshold, median_uncertainties = np.inf, {}
    else:
        # nanpercentile: a single NaN error would otherwise make the
        # threshold NaN and silently disable every outlier check.
        error_threshold = np.nanpercentile(preliminary['error'], 95)
        median_uncertainties = _median_uncertainty_by_bin(preliminary, _magnitude_bins(preliminary))

    # Final fit with error threshold and median uncertainties.
    df_results = run_pass(error_threshold, median_uncertainties)
    if df_results.empty:
        logging.warning("No constants could be fitted.")
        logging.info("Symbolic fit completed.")
        return df_results

    # Physical consistency check
    for bad in check_physical_consistency(df_results):
        is_bad = df_results['name'] == bad['name']
        df_results.loc[is_bad, 'bad_data'] = True
        df_results.loc[is_bad, 'bad_data_reason'] = (
            df_results.loc[is_bad, 'bad_data_reason'] + "; " + bad['reason']
        ).str.strip("; ")

    # Uncertainty outlier check using IQR. The final pass may keep a different
    # set of rows than the preliminary one, so the bins are recomputed here.
    _flag_uncertainty_outliers(df_results, _magnitude_bins(df_results))
    logging.info("Symbolic fit completed.")
    return df_results

def total_error(params, df):
    """Optimization objective: summed squared relative error of the fit.

    Args:
        params: sequence ``(r, k, Omega, base, scale)`` supplied by
            scipy.optimize.minimize.
        df: constants DataFrame (as returned by parse_codata_ascii).

    Returns:
        Sum of ``((value - approx) / value) ** 2`` over the fits whose
        absolute error is finite and at most the 95th percentile of the finite
        errors. Returns ``np.inf`` if fitting raises or no finite errors
        remain, so a degenerate fit never scores as a perfect one.
    """
    r, k, Omega, base, scale = params
    try:
        df_fit = symbolic_fit_all_constants(df, r=r, k=k, Omega=Omega, base=base, scale=scale, max_n=500, steps=200)
        # Non-finite errors can make the percentile NaN. That filters out every
        # row and would report a perfect score of 0, so keep only finite errors.
        finite = df_fit[np.isfinite(df_fit['error'])]
        if finite.empty:
            logging.warning("total_error: no finite fit errors; returning inf")
            return np.inf
        threshold = np.percentile(finite['error'], 95)
        filtered = finite[finite['error'] <= threshold]
        rel_err = ((filtered['value'] - filtered['approx']) / filtered['value'])**2
        return rel_err.sum()
    except Exception as e:
        logging.error(f"total_error failed: {e}")
        return np.inf

if __name__ == "__main__":
    # Script entry point: parse CODATA, optimize the model parameters on a
    # subset, fit all constants, then print reports and save the CSV and plots.
    script_start = time.time()
    report_columns = ['name', 'value', 'unit', 'n', 'beta', 'approx', 'error', 'uncertainty', 'emergent_uncertainty', 'r_local', 'k_local', 'scale', 'bad_data', 'bad_data_reason']

    print("Parsing CODATA constants from allascii.txt...")
    start_time = time.time()
    codata_df = parse_codata_ascii("allascii.txt")
    print(f"Parsed {len(codata_df)} constants in {time.time() - start_time:.2f} seconds.")

    # Use a smaller subset for optimization
    subset_df = codata_df.head(20)
    init_params = [1.0, 1.0, 1.0, 2.0, 1.0]
    bounds = [(1e-5, 10), (1e-5, 10), (1e-5, 10), (1.5, 10), (1e-5, 100)]

    print("Optimizing symbolic model parameters...")
    start_time = time.time()
    res = minimize(total_error, init_params, args=(subset_df,), bounds=bounds, method='L-BFGS-B', options={'maxiter': 50})
    if not res.success:
        logging.warning(f"Parameter optimization did not converge: {res.message}")
    r_opt, k_opt, Omega_opt, base_opt, scale_opt = res.x
    print(f"Optimization complete in {time.time() - start_time:.2f} seconds. Found parameters:\nr = {r_opt:.6f}, k = {k_opt:.6f}, Omega = {Omega_opt:.6f}, base = {base_opt:.6f}, scale = {scale_opt:.6f}")

    print("Fitting symbolic dimensions to all constants...")
    start_time = time.time()
    fitted_df = symbolic_fit_all_constants(codata_df, r=r_opt, k=k_opt, Omega=Omega_opt, base=base_opt, scale=scale_opt, max_n=500, steps=200)
    if fitted_df.empty:
        # An empty result has no columns, so the reports below would raise.
        print("No constants could be fitted. Check symbolic_fit.log for details.")
        raise SystemExit(1)
    fitted_df_sorted = fitted_df.sort_values("error")
    print(f"Fitting complete in {time.time() - start_time:.2f} seconds.")

    print("\nTop 20 best symbolic fits:")
    print(fitted_df_sorted.head(20)[report_columns].to_string(index=False))

    print("\nTop 20 worst symbolic fits:")
    print(fitted_df_sorted.tail(20)[report_columns].to_string(index=False))

    print("\nPotentially bad data constants summary:")
    bad_data_df = fitted_df[fitted_df['bad_data'].eq(True)][['name', 'value', 'error', 'rel_error', 'uncertainty', 'emergent_uncertainty', 'bad_data_reason']]
    print(bad_data_df.to_string(index=False))

    fitted_df_sorted.to_csv("symbolic_fit_results_emergent_optimized.txt", sep="\t", index=False)

    plt.figure(figsize=(10, 5))
    plt.hist(fitted_df_sorted['error'], bins=50, color='skyblue', edgecolor='black')
    plt.title('Histogram of Absolute Errors in Symbolic Fit')
    plt.xlabel('Absolute Error')
    plt.ylabel('Count')
    plt.grid(True)
    plt.tight_layout()
    plt.savefig("error_histogram.png")
    plt.close()

    plt.figure(figsize=(10, 5))
    plt.scatter(fitted_df_sorted['n'], fitted_df_sorted['error'], alpha=0.5, s=15, c='orange', edgecolors='black')
    plt.title('Absolute Error vs Symbolic Dimension n')
    plt.xlabel('n')
    plt.ylabel('Absolute Error')
    plt.grid(True)
    plt.tight_layout()
    plt.savefig("error_vs_n.png")
    plt.close()

    # Measured from script start; start_time only covers the latest phase.
    print(f"Total runtime: {time.time() - script_start:.2f} seconds. Check symbolic_fit.log for details.")